import wave
import pyaudio
import numpy as np
import librosa
import matplotlib.pyplot as plt
import cv2
import MyDenseNet

# Audio stream parameters
CHUNK = 1024
FORMAT = pyaudio.paInt16
CHANNELS = 1
RATE = 11025
WAVE_OUTPUT_FILENAME = "temp.wav"

# Record audio from the microphone into a temporary wav file
def SoundRecording(sec = 5):
    # Recording duration in seconds
    RECORD_SECONDS = sec
    
    # Create the PyAudio object
    p = pyaudio.PyAudio()
    
    # Open the input stream
    stream = p.open(format=FORMAT,
                    channels=CHANNELS,
                    rate=RATE,
                    input=True,
                    frames_per_buffer=CHUNK)
    
    print("* recording")
    
    # Start recording, reading CHUNK frames at a time
    frames = []
    for i in range(0, int(RATE / CHUNK * RECORD_SECONDS)):
        data = stream.read(CHUNK)
        frames.append(data)
    
    print("* done recording")

    # Stop and close the stream
    stream.stop_stream()
    stream.close()
    
    # Query the sample width before releasing PortAudio, then shut down
    sample_width = p.get_sample_size(FORMAT)
    p.terminate()

    # Write the recording to a wav file
    wf = wave.open(WAVE_OUTPUT_FILENAME, 'wb')
    wf.setnchannels(CHANNELS)
    wf.setsampwidth(sample_width)
    wf.setframerate(RATE)
    wf.writeframes(b''.join(frames))
    wf.close()
    return WAVE_OUTPUT_FILENAME
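
# Example usage (overwrites temp.wav in the working directory):
#   wav_path = SoundRecording(sec=3)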

    
'''
Generate the spectrogram of a waveform, save it as tmp.png, and return it
resized to the spectrogram's bin dimensions.
groupNp      waveform samples
NFFT         FFT size
framerate    sample rate
framesize    samples per frame
overlapSize  frame overlap in samples
'''
def ExportImg(groupNp, NFFT, framerate, framesize, overlapSize):
    # Draw the spectrogram; specgram returns (spectrum, freqs, times, image)
    spectrum, freqs, ts, im = plt.specgram(groupNp,
                                           NFFT=NFFT,
                                           Fs=framerate,
                                           window=np.hanning(framesize),
                                           noverlap=overlapSize,
                                           mode='default',
                                           scale_by_freq=True,
                                           sides='onesided',
                                           scale='dB',
                                           xextent=None)
    
    # Strip axes, ticks, and margins so only the spectrogram pixels are saved
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    plt.axis('off')
    plt.savefig("tmp.png")
    '''   
    plt.ylabel('Frequency')
    plt.xlabel('Time')
    plt.title("Spectrogram")
    plt.savefig("tmp1.png")  
    '''   
    # Close the figure so repeated calls do not accumulate plots
    # (plt.show() here would block the interactive loop)
    plt.close()
    src = cv2.imread("tmp.png")
    reshape = np.shape(spectrum)
    imageChange = cv2.resize(src, (reshape[1], reshape[0]))
    #imageChange = cv2.cvtColor(imageChange, cv2.COLOR_RGB2GRAY)
    #rgb  rg=黄色  判断r b的大小来区分前景背景
    return imageChange
     
'''
Generate a spectrogram image from a wav signal.
Selection idea: clip 1.5 s segments of high sound intensity to produce
voiceprint spectrograms, and 1.5 s stretches of low intensity to produce
background-noise spectrograms.
nplist       wav samples
framelength  frame duration, default 0.025 s
framerate    wav sample rate, default 11025
'''
def WavToImg(nplist, framelength = 0.025, framerate = 11025):
    framesize = framelength * framerate  # samples per frame, N = t * fs; usually 256 or 512
                                         # and equal to NFFT, which should be a power of 2
 
    # Find the power of 2 closest to the current framesize
    nfftdict = {}
    lists = [32, 64, 128, 256, 512, 1024]
    for i in lists:
        nfftdict[i] = abs(framesize - i)
    sortlist = sorted(nfftdict.items(), key=lambda x: x[1])  # sort by distance to framesize, ascending
    framesize = int(sortlist[0][0])  # take the closest power of 2 as the new framesize
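    # A compact alternative picks the nearest power of two in log space, which
    # can differ from the linear-distance choice above near interval midpoints:
    #   framesize = 2 ** int(round(np.log2(framesize)))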
     
    NFFT = framesize  # NFFT must equal the time-domain frame length framesize (i.e. no zero padding)
    overlapSize = 1.0 / 2 * framesize  # overlap is typically 1/3 to 1/2 of the frame length
    overlapSize = int(round(overlapSize))
    print("Frame size: {}, overlap: {}, FFT points: {}".format(framesize, overlapSize, NFFT))

    data = ExportImg(nplist, NFFT, framerate, framesize, overlapSize)
  
    return data
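
# The WavToImg docstring mentions clipping 1.5 s high-intensity segments to
# make voiceprint spectrograms, but the function above passes the whole signal
# through. A minimal sketch of that selection step (hypothetical helper, not
# called anywhere in this module), assuming RMS energy over sliding 1.5 s
# windows:
def ClipLoudestSegment(nplist, framerate = 11025, seconds = 1.5):
    win = int(seconds * framerate)  # samples in a 1.5 s window
    if len(nplist) <= win:
        return nplist
    step = max(1, win // 4)  # slide by a quarter window
    starts = list(range(0, len(nplist) - win, step))
    # RMS energy of each candidate window
    energies = [np.sqrt(np.mean(np.square(nplist[s:s + win]))) for s in starts]
    best = starts[int(np.argmax(energies))]
    return nplist[best:best + win]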

# Recognize a wav file; returns the label number and the feature vector.
# Example calls:
#a = RecognizeVoice('DataSet\\0001\\BAC009S0764W0144.wav')
#a = RecognizeVoice('DataSet\\0002\\BAC009S0765W0196.wav')
#RecognizeVoice('DataSet\\0003\\BAC009S0766W0230.wav')
#RecognizeVoice('DataSet\\0004\\BAC009S0767W0208.wav')
def RecognizeVoice(fileName, printOut = False):
    y, sr = librosa.load(fileName, sr=11025)
    imgData = WavToImg(y)
    res = MyDenseNet.RecognizeData([imgData])
    # Label 0 is the background class: a clip dominated by silence or
    # background noise is rejected rather than matched
    label = np.argmax(res[0][0])
    if label == 0:
        print('No valid voice detected')
        return None
    print('Recognized label: ' + str(label))
    print('Feature vector:')
    print(res[1])
    margin = 5
    if printOut:
        minKey = ''
        minValue = -1.0
        for k, v in dicRegistList.items():
            dis = np.sum(np.square(res[1] - v))  # squared Euclidean distance
            # Map distance to a [0, 1] similarity score
            likeRate = max(0.0, (margin - dis) / margin)
            print('Compared with ' + k + ', similarity: ' + str(likeRate))
            # Track the closest registered voice
            if minValue < 0 or dis < minValue:
                minValue = dis
                minKey = k
        if minValue > -1.0:
            print('Recognized name: ' + minKey)
    return res
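
# For reference, the similarity mapping above clamps (margin - dis) / margin
# into [0, 1] with margin = 5:
#   dis = 0.0 -> similarity 1.0 (identical feature vectors)
#   dis = 2.5 -> similarity 0.5
#   dis >= 5  -> similarity 0.0 (at or beyond the margin)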

# Record from the microphone and recognize the speaker
def RecordAndRecognize():
    tmp_file = SoundRecording()
    res = RecognizeVoice(tmp_file, True)
    return res

# Registered voices: maps a name to its feature vector
dicRegistList = {}

#RegistVoice('DataSet\\0001\\BAC009S0764W0144.wav')
#RegistVoice('DataSet\\0002\\BAC009S0765W0196.wav')
#RegistVoice('DataSet\\0003\\BAC009S0766W0230.wav')
#RegistVoice('DataSet\\0004\\BAC009S0767W0208.wav')
def RegistVoice(fileName = ''):
    strName = input('Enter a name: ')
    if fileName == '':
        fileName = SoundRecording()
    res = RecognizeVoice(fileName)
    if res is not None:
        dicRegistList[strName] = res[1]  # store the feature vector
    
if __name__ == '__main__':
    while True:
        cmd = input('Select a command: 0 register, 1 recognize, 2 quit: ')
        if cmd == '0':
            RegistVoice()
        elif cmd == '1':
            RecordAndRecognize()
        elif cmd == '2':
            break